{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Binding: Attach runtime args\n",
    "\n",
    "Sometimes we want to invoke a Runnable within a Runnable sequence with constant arguments that are not part of the output of the preceding Runnable in the sequence, and which are not part of the user input. We can use Runnable.bind() to pass these arguments in.<br>\n",
    "有时，我们希望在 Runnable 序列中以常量参数调用某个 Runnable，这些参数既不是序列中前一个 Runnable 输出的一部分，也不是用户输入的一部分。我们可以使用 Runnable.bind() 来传递这些参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EQUATION:\n",
      " \\(x^3 + 7 = 12\\)\n",
      "\n",
      "SOLUTION:\n",
      "Subtract 7 from both sides:\n",
      "\\[x^3 = 5\\]\n",
      "\n",
      "Take the cube root of both sides:\n",
      "\\[x = \\sqrt[3]{5}\\]\n",
      "\n",
      "Therefore, the solution is \\(x = \\sqrt[3]{5}\\).\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# System instructions: write the equation symbolically, then solve it in a fixed layout.\n",
    "equation_system_message = \"Write out the following equation using algebraic symbols then solve it. Use the format\\n\\nEQUATION:...\\nSOLUTION:...\\n\\n\"\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"system\", equation_system_message), (\"human\", \"{equation_statement}\")]\n",
    ")\n",
    "model = ChatOpenAI(temperature=0)\n",
    "\n",
    "# The string given to invoke() is passed through as `equation_statement`,\n",
    "# formatted into the prompt, sent to the model, and parsed to a plain string.\n",
    "chain_input = {\"equation_statement\": RunnablePassthrough()}\n",
    "runnable = chain_input | prompt | model | StrOutputParser()\n",
    "\n",
    "print(runnable.invoke(\"x raised to the third plus seven equals 12\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EQUATION: x^3 + 7 = 12\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Binding `stop` ends generation before the model writes \"SOLUTION\", so only\n",
    "# the EQUATION part is returned. Chat models declare `stop` as a list of\n",
    "# strings, so pass a list rather than a bare string.\n",
    "runnable = (\n",
    "    {\"equation_statement\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | model.bind(stop=[\"SOLUTION\"])\n",
    "    | StrOutputParser()\n",
    ")\n",
    "print(runnable.invoke(\"x raised to the third plus seven equals 12\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
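    "## Attaching OpenAI functions\n",
    "\n",
    "Binding is also useful for attaching an OpenAI function definition to a model. Below, `functions` supplies the function schema and `function_call` names the function the model should call."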
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{\\n\"equation\": \"x^3 + 7 = 12\",\\n\"solution\": \"x = ∛5\"\\n}', 'name': 'solver'}}, response_metadata={'token_usage': {'completion_tokens': 33, 'prompt_tokens': 89, 'total_tokens': 122}, 'model_name': 'gpt-4-0613', 'system_fingerprint': None, 'finish_reason': 'function_call', 'logprobs': None}, id='run-b1730439-dcb5-4830-aa6b-4fe30ce1b610-0', usage_metadata={'input_tokens': 89, 'output_tokens': 33, 'total_tokens': 122})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# JSON-schema description of the arguments the `solver` function accepts.\n",
    "solver_parameters = {\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\n",
    "        \"equation\": {\n",
    "            \"type\": \"string\",\n",
    "            \"description\": \"The algebraic expression of the equation\",\n",
    "        },\n",
    "        \"solution\": {\n",
    "            \"type\": \"string\",\n",
    "            \"description\": \"The solution to the equation\",\n",
    "        },\n",
    "    },\n",
    "    \"required\": [\"equation\", \"solution\"],\n",
    "}\n",
    "function = {\n",
    "    \"name\": \"solver\",\n",
    "    \"description\": \"Formulates and solves an equation\",\n",
    "    \"parameters\": solver_parameters,\n",
    "}\n",
    "\n",
    "# Need gpt-4 to solve this one correctly\n",
    "solver_system_message = (\n",
    "    \"Write out the following equation using algebraic symbols then solve it.\"\n",
    ")\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [(\"system\", solver_system_message), (\"human\", \"{equation_statement}\")]\n",
    ")\n",
    "\n",
    "# `functions` supplies the schema; `function_call` names the function to call.\n",
    "solver_llm = ChatOpenAI(model=\"gpt-4\", temperature=0)\n",
    "model = solver_llm.bind(function_call={\"name\": \"solver\"}, functions=[function])\n",
    "\n",
    "runnable = {\"equation_statement\": RunnablePassthrough()} | prompt | model\n",
    "runnable.invoke(\"x raised to the third plus seven equals 12\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attaching OpenAI tools\n",
    "\n",
    "Tool definitions can be attached the same way, by passing them to `bind(tools=...)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_ct87786a4C9msfovYYNaFimp', 'function': {'arguments': '{\"location\": \"San Francisco, CA\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}, {'id': 'call_8nJP2ZNwA2tB8iYeMm3PsXab', 'function': {'arguments': '{\"location\": \"New York, NY\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}, {'id': 'call_9CMoYqUsA309gzkgGxjeSmLk', 'function': {'arguments': '{\"location\": \"Los Angeles, CA\", \"unit\": \"celsius\"}', 'name': 'get_current_weather'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 84, 'prompt_tokens': 85, 'total_tokens': 169}, 'model_name': 'gpt-3.5-turbo-1106', 'system_fingerprint': 'fp_b81a85d4e3', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-8cf242ff-58e2-4ada-a86c-492a2e75ad54-0', tool_calls=[{'name': 'get_current_weather', 'args': {'location': 'San Francisco, CA', 'unit': 'celsius'}, 'id': 'call_ct87786a4C9msfovYYNaFimp', 'type': 'tool_call'}, {'name': 'get_current_weather', 'args': {'location': 'New York, NY', 'unit': 'celsius'}, 'id': 'call_8nJP2ZNwA2tB8iYeMm3PsXab', 'type': 'tool_call'}, {'name': 'get_current_weather', 'args': {'location': 'Los Angeles, CA', 'unit': 'celsius'}, 'id': 'call_9CMoYqUsA309gzkgGxjeSmLk', 'type': 'tool_call'}], usage_metadata={'input_tokens': 85, 'output_tokens': 84, 'total_tokens': 169})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Argument schema for the weather tool: `location` is required, `unit` is optional.\n",
    "weather_parameters = {\n",
    "    \"type\": \"object\",\n",
    "    \"properties\": {\n",
    "        \"location\": {\n",
    "            \"type\": \"string\",\n",
    "            \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
    "        },\n",
    "        \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
    "    },\n",
    "    \"required\": [\"location\"],\n",
    "}\n",
    "weather_function = {\n",
    "    \"name\": \"get_current_weather\",\n",
    "    \"description\": \"Get the current weather in a given location\",\n",
    "    \"parameters\": weather_parameters,\n",
    "}\n",
    "tools = [{\"type\": \"function\", \"function\": weather_function}]\n",
    "\n",
    "# Bind the tool definitions to the model before invoking it.\n",
    "model = ChatOpenAI(model=\"gpt-3.5-turbo-1106\").bind(tools=tools)\n",
    "model.invoke(\"What's the weather in SF, NYC and LA?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain0_1",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
